{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "136a4efe-fb99-4311-8679-e0a5b6282755",
   "metadata": {},
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "<br>汉化的库: <a href=\"https://github.com/GoatCsu/CN-LLMs-from-scratch.git\">https://github.com/GoatCsu/CN-LLMs-from-scratch.git</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1910a06-e8a3-40ac-8201-ff70615b1ba4",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 使用 LLaMA 3.1 70B 和 Ollama 生成偏好数据集  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a128651b-f326-4232-a994-42f38b7ed520",
   "metadata": {},
   "source": [
    "- **偏好微调（Preference Finetuning）** 旨在使 **指令微调后的 LLM** 更加符合 **人类偏好**。  \n",
    "- 生成 **偏好微调数据集** 有多种方法：\n",
    "  1. **使用指令微调 LLM 生成多个响应**，并由 **人工根据偏好标准进行排序**。  \n",
    "  2. **使用指令微调 LLM 生成多个响应**，并由 **LLM 根据设定的偏好标准进行排序**。  \n",
    "  3. **使用 LLM 基于特定偏好标准直接生成偏好（Preferred）和非偏好（Dispreferred）响应**。  \n",
    "\n",
    "- **本笔记本采用方法 3**。  \n",
    "- 这里使用 **70B 参数的 LLaMA 3.1-Instruct 模型**（通过 **Ollama** 运行）为 **指令数据集生成偏好标签**。  \n",
    "- **期望的指令数据集格式如下**：\n",
    "\n",
    "### 输入（Input）\n",
    "\n",
    "\n",
    "```json\n",
    "[\n",
    "    {\n",
    "        \"instruction\": \"What is the state capital of California?\",\n",
    "        \"input\": \"\",\n",
    "        \"output\": \"The state capital of California is Sacramento.\",\n",
    "    },\n",
    "    {\n",
    "        \"instruction\": \"Provide a synonym for 'fast'.\",\n",
    "        \"input\": \"\",\n",
    "        \"output\": \"A synonym for 'fast' is 'quick'.\",\n",
    "    },\n",
    "    {\n",
    "        \"instruction\": \"What is the capital of Greece?\",\n",
    "        \"input\": \"\",\n",
    "        \"output\": \"The capital of Greece is Athens.\",\n",
    "\n",
    "    },\n",
    "...\n",
    "]\n",
    "```\n",
    "生成的数据集格式如下，其中 **较礼貌的响应** 被标记为 **`'chosen'`（偏好响应）**，**较不礼貌的响应** 被标记为 **`'rejected'`（非偏好响应）**：\n",
    "\n",
    "\n",
    "```json\n",
    "[\n",
    "    {\n",
    "        \"instruction\": \"What is the state capital of California?\",\n",
    "        \"input\": \"\",\n",
    "        \"output\": \"The state capital of California is Sacramento.\",\n",
    "        \"rejected\": \"Look, the state capital of California is obviously Sacramento.\",\n",
    "        \"chosen\": \"The state capital of California is Sacramento.\"\n",
    "    },\n",
    "    {\n",
    "        \"instruction\": \"Provide a synonym for 'fast'.\",\n",
    "        \"input\": \"\",\n",
    "        \"output\": \"A synonym for 'fast' is 'quick'.\",\n",
    "        \"chosen\": \"A suitable alternative to 'fast' would be 'quick'.\",\n",
    "        \"rejected\": \"A synonym for 'fast' is 'quick'.\"\n",
    "    },\n",
    "    {\n",
    "        \"instruction\": \"What is the capital of Greece?\",\n",
    "        \"input\": \"\",\n",
    "        \"output\": \"The capital of Greece is Athens.\",\n",
    "        \"chosen\": \"I'd be happy to help! The capital of Greece is indeed Athens.\",\n",
    "        \"rejected\": \"The capital of Greece is Athens.\"\n",
    "    },\n",
    "...\n",
    "]\n",
    "```\n",
    "\n",
    "### 输出（Output）\n",
    "\n",
    "- 该代码 **无需 GPU**，在 **RAM 充足的笔记本电脑** 上即可运行。  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "63610acc-db94-437f-8d38-e99dca0299cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tqdm version: 4.66.4\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\"tqdm\",    # 进度条\n",
    "        ]\n",
    "\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bcdcb34-ac75-4f4f-9505-3ce0666c42d5",
   "metadata": {},
   "source": [
    "## 安装 Ollama 并下载 LLaMA 3.1\n",
    "\n",
    "- **Ollama** 是一个用于高效运行 **LLM（大语言模型）** 的应用。  \n",
    "- 它是 **[llama.cpp](https://github.com/ggerganov/llama.cpp)** 的封装，后者采用 **纯 C/C++ 实现 LLM**，以 **最大化推理效率**。  \n",
    "- **请注意**，Ollama **仅用于 LLM 推理（inference）**，**不支持训练或微调（finetuning）**。  \n",
    "- **在运行下方代码前**，请先访问 **[https://ollama.com](https://ollama.com)** 并按照安装指南完成 **Ollama 安装**（例如，点击 **“Download”** 按钮，下载适用于您的操作系统的 Ollama 应用）。  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9558a522-650d-401a-84fc-9fd7b1f39da7",
   "metadata": {},
   "source": [
    "- **对于 macOS 和 Windows 用户**，点击 **下载的 Ollama 应用**，如果系统提示安装 **命令行工具**，请选择 **“是”**。  \n",
    "- **Linux 用户** 可以使用 **Ollama 官网提供的安装命令** 进行安装。  \n",
    "\n",
    "- **通常，在命令行使用 Ollama 之前**，需要 **启动 Ollama 应用** 或 **在终端运行 `ollama serve`**。  \n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/ch7/21.png\">\n",
    "\n",
    "- **确保 Ollama 运行后**，在 **另一个终端窗口** 执行以下命令，尝试 **700 亿参数的 LLaMA 3.1 模型**：  \n",
    "\n",
    "\n",
    "```bash\n",
    "# 70B 模型\n",
    "ollama run llama3.1:70b\n",
    "```\n",
    "\n",
    "\n",
    "The output looks like as follows:\n",
    "\n",
    "```\n",
    "$ ollama run llama3.1:70b\n",
    "pulling manifest\n",
    "pulling aa81b541aae6... 100% ▕████████████████▏ 39 GB\n",
    "pulling 8cf247399e57... 100% ▕████████████████▏ 1.7 KB\n",
    "pulling f1cd752815fc... 100% ▕████████████████▏ 12 KB\n",
    "pulling 56bb8bd477a5... 100% ▕████████████████▏ 96 B\n",
    "pulling 3c1c2d3df5b3... 100% ▕████████████████▏ 486 B\n",
    "verifying sha256 digest\n",
    "writing manifest\n",
    "removing any unused layers\n",
    "success\n",
    "```\n",
    "\n",
    "- **注意**：`llama3.1:70b` 指的是 **指令微调后的 700 亿参数 LLaMA 3.1 模型**。  \n",
    "\n",
    "- **如果您的设备资源有限**，可以选择 **更轻量的 80 亿参数 LLaMA 3.1 模型**，  \n",
    "  **只需将 `llama3.1:70b` 替换为 `llama3.1`**。  \n",
    "\n",
    "- **下载完成后**，您将进入 **命令行交互界面**，可与模型进行对话。  \n",
    "\n",
    "- **尝试输入以下提示**：\"What do llamas eat?\"（羊驼吃什么？），  \n",
    "  预计模型会返回类似如下的输出：  \n",
    "\n",
    "\n",
    "```\n",
    ">>> What do llamas eat?\n",
    "Llamas are ruminant animals, which means they have a four-chambered \n",
    "stomach and eat plants that are high in fiber. In the wild, llamas \n",
    "typically feed on:\n",
    "1. Grasses: They love to graze on various types of grasses, including tall \n",
    "grasses, wheat, oats, and barley.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b5addcb-fc7d-455d-bee9-6cc7a0d684c7",
   "metadata": {},
   "source": [
    "- 输入`/bye`以结束这一节"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dda155ee-cf36-44d3-b634-20ba8e1ca38a",
   "metadata": {},
   "source": [
    "## 使用 Ollama 的 REST API\n",
    "\n",
    "- **另一种与模型交互的方式** 是通过 **REST API** 在 **Python** 中进行调用，具体实现如下。  \n",
    "- **在运行本笔记本中的代码前**，请确保 **Ollama 仍在运行**，可通过以下方式启动：\n",
    "  - 在终端中执行 `ollama serve`\n",
    "  - 使用 **Ollama 应用程序**  \n",
    "\n",
    "- **接下来，运行下方代码单元**，以查询模型并获取响应。  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16642a48-1cab-40d2-af08-ab8c2fbf5876",
   "metadata": {},
   "source": [
    "- 首先，我们使用 **一个简单示例** 调用 API，以确保其 **正常运行**：  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "65b0ba76-1fb1-4306-a7c2-8f3bb637ccdb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet consists of:\n",
      "\n",
      "1. **Grasses**: Various types of grasses, including timothy grass, orchard grass, and brome grass.\n",
      "2. **Hay**: High-quality hay, such as alfalfa or clover hay, is a staple in a llama's diet.\n",
      "3. **Leaves**: Leaves from trees and shrubs, like willow, cottonwood, and mesquite, are also eaten.\n",
      "4. **Fruits and vegetables**: Llamas enjoy fruits like apples, carrots, and sweet potatoes, as well as leafy greens like kale and spinach.\n",
      "5. **Grains**: In moderation, llamas can eat grains like oats, barley, and corn.\n",
      "\n",
      "It's essential to note that llamas have a unique digestive system, with a three-part stomach and a large cecum (a specialized part of the large intestine). This allows them to break down and extract nutrients from plant material more efficiently than many other animals.\n",
      "\n",
      "A typical llama diet might consist of:\n",
      "\n",
      "* 1-2% of their body weight in hay per day\n",
      "* 0.5-1% of their body weight in grains per day (if fed)\n",
      "* Free-choice access to fresh water\n",
      "* Limited amounts of fruits and vegetables as treats\n",
      "\n",
      "It's also important to ensure that llamas have access to a mineral supplement, such as a salt lick or loose minerals, to help maintain optimal health.\n",
      "\n",
      "Remember, every llama is different, and their dietary needs may vary depending on factors like age, size, and activity level. Consult with a veterinarian or experienced llama breeder for specific guidance on feeding your llama.\n"
     ]
    }
   ],
   "source": [
    "import urllib.request\n",
    "import json\n",
    "\n",
    "\n",
    "def query_model(prompt, model=\"llama3.1:70b\", url=\"http://localhost:11434/api/chat\"):\n",
    "    # 创建数据负载作为字典\n",
    "    data = {\n",
    "        \"model\": model,\n",
    "        \"messages\": [\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": prompt\n",
    "            }\n",
    "        ],\n",
    "        \"options\": {\n",
    "            \"seed\": 123,\n",
    "            \"temperature\": 0,\n",
    "        }\n",
    "    }\n",
    "\n",
    "    # 将字典转换为 JSON 格式的字符串并编码为字节\n",
    "    payload = json.dumps(data).encode(\"utf-8\")\n",
    "\n",
    "    # 创建请求对象，设置方法为 POST 并添加必要的头信息\n",
    "    request = urllib.request.Request(url, data=payload, method=\"POST\")\n",
    "    request.add_header(\"Content-Type\", \"application/json\")\n",
    "\n",
    "    # 发送请求并捕获响应\n",
    "    response_data = \"\"\n",
    "    with urllib.request.urlopen(request) as response:\n",
    "        # 读取并解码响应\n",
    "        while True:\n",
    "            line = response.readline().decode(\"utf-8\")\n",
    "            if not line:\n",
    "                break\n",
    "            response_json = json.loads(line)\n",
    "            response_data += response_json[\"message\"][\"content\"]\n",
    "\n",
    "    return response_data\n",
    "\n",
    "\n",
    "result = query_model(\"What do Llamas eat?\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "162a4739-6f03-4092-a5c2-f57a0b6a4c4d",
   "metadata": {},
   "source": [
    "## 加载 JSON 数据（Load JSON Entries）\n",
    "\n",
    "- 现在，我们进入 **数据生成** 部分。  \n",
    "- **为了直观演示**，我们将使用 **`instruction-data.json`** 文件，  \n",
    "  该文件最初用于 **第 7 章的指令微调（Instruction Finetuning）**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8b2d393a-aa92-4190-9d44-44326a6f699b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of entries: 1100\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "json_file = Path(\"..\", \"01_main-chapter-code\", \"instruction-data.json\")\n",
    "\n",
    "with open(json_file, \"r\") as file:\n",
    "    json_data = json.load(file)\n",
    "\n",
    "print(\"Number of entries:\", len(json_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6c9751b-59b7-43fe-acc7-14e8daf2fa66",
   "metadata": {},
   "source": [
    "- **该文件的结构如下**，其中：\n",
    "  - `'output'`：测试数据集中提供的 **预期响应**，即模型通过 **指令微调（Instruction Finetuning）** 训练后应生成的内容。  \n",
    "  - `'input'` 和 `'instruction'`：用于指导模型生成 `'output'` 的 **输入数据**。  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7222fdc0-5684-4f2b-b741-3e341851359e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'instruction': 'Evaluate the following phrase by transforming it into the spelling given.',\n",
       " 'input': 'freind --> friend',\n",
       " 'output': 'The spelling of the given phrase \"freind\" is incorrect, the correct spelling is \"friend\".'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "json_data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcf0331b-6024-4bba-89a9-a088b14a1046",
   "metadata": {},
   "source": [
    "- 下面是一个 **小型工具函数**，用于格式化 **指令（instruction）和输入（input）**：  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "43263cd3-e5fb-4ab5-871e-3ad6e7d21a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_input(entry):\n",
    "    instruction_text = (\n",
    "        f\"Below is an instruction that describes a task. Write a response that \"\n",
    "        f\"appropriately completes the request.\"\n",
    "        f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
    "    )\n",
    "\n",
    "    input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
    "    instruction_text + input_text\n",
    "\n",
    "    return instruction_text + input_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39a55283-7d51-4136-ba60-f799d49f4098",
   "metadata": {},
   "source": [
    "- 现在，我们使用 **Ollama API** 生成 **`'chosen'`（偏好）** 和 **`'rejected'`（非偏好）** 响应，  \n",
    "  以进行 **模型的偏好微调（Preference Tuning）**。  \n",
    "- **为了直观演示**，这里生成的回答在 **礼貌程度** 上存在 **明显差异**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "735cc089-d127-480a-b39d-0782581f0c41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dataset response:\n",
      ">> The spelling of the given phrase \"freind\" is incorrect, the correct spelling is \"friend\".\n",
      "\n",
      "impolite response:\n",
      ">> The spelling of the given phrase \"freind\" is flat out wrong, get it together, the correct spelling is \"friend\".\n",
      "\n",
      "Dataset response:\n",
      ">> He goes to the park every day.\n",
      "\n",
      "polite response:\n",
      ">> He goes to the park daily, if I'm not mistaken.\n",
      "\n",
      "Dataset response:\n",
      ">> 45 kilometers is 45000 meters.\n",
      "\n",
      "polite response:\n",
      ">> 45 kilometers is equivalent to 45000 meters.\n",
      "\n",
      "Dataset response:\n",
      ">> Although it was raining, they went for a walk.\n",
      "\n",
      "polite response:\n",
      ">> Although it was raining outside, they still decided to go for a walk.\n",
      "\n",
      "Dataset response:\n",
      ">> 1, 4, 9, 16, 25, 36, 49, 64, 81, 100.\n",
      "\n",
      "impolite response:\n",
      ">> Here are your precious square numbers: 1, 4, 9, 16, 25, 36, 49, 64, 81, 100.\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "for entry in json_data[:5]:\n",
    "    \n",
    "    politeness = random.choice([\"polite\", \"impolite\"])    \n",
    "    prompt = (\n",
    "        f\"Given the input `{format_input(entry)}` \"\n",
    "        f\"and correct output `{entry['output']}`, \"\n",
    "        f\"slightly rewrite the output to be more {politeness}.\"\n",
    "        \"Keep the modification minimal.\"\n",
    "        \"Only return return the generated response and nothing else.\"\n",
    "    )\n",
    "    print(\"\\nDataset response:\")\n",
    "    print(\">>\", entry['output'])\n",
    "    print(f\"\\n{politeness} response:\")\n",
    "    print(\">>\", query_model(prompt))    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "142dfaa7-429f-4eb0-b74d-ff327f79547a",
   "metadata": {},
   "source": [
    "- **如果我们认为上面生成的响应较为合理**，可以进入 **下一步**，将该提示（prompt）应用于 **整个数据集**。  \n",
    "- **在数据集中添加**：\n",
    "  - **`'chosen'`**：代表 **偏好（preferred）响应**  \n",
    "  - **`'rejected'`**：代表 **非偏好（dispreferred）响应**  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3349dbbc-963f-4af3-9790-12dbfdca63c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from tqdm import tqdm\n",
    "\n",
    "def generate_model_responses(json_data):\n",
    "\n",
    "    for i, entry in enumerate(tqdm(json_data, desc=\"Writing entries\")):\n",
    "        politeness = random.choice([\"polite\", \"impolite\"])    \n",
    "        prompt = (\n",
    "            f\"Given the input `{format_input(entry)}` \"\n",
    "            f\"and correct output `{entry['output']}`, \"\n",
    "            f\"slightly rewrite the output to be more {politeness}.\"\n",
    "            \"Keep the modification minimal.\"\n",
    "            \"Only return return the generated response and nothing else.\"\n",
    "        )\n",
    "        response = query_model(prompt)\n",
    "        \n",
    "        if politeness == \"polite\":\n",
    "            json_data[i][\"chosen\"] = response\n",
    "            json_data[i][\"rejected\"] = entry[\"output\"]\n",
    "        else:\n",
    "            json_data[i][\"rejected\"] = response\n",
    "            json_data[i][\"chosen\"] = entry[\"output\"]    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b071ce84-1866-427f-a272-b46700f364b2",
   "metadata": {},
   "source": [
    "- 现在，我们对 **整个数据集** 进行评估，并计算 **每个模型的平均分**（在 **M3 MacBook Air** 上运行 **每个模型约需 1 分钟**）。  \n",
    "- **请注意**，截至目前，Ollama **在不同操作系统上的推理结果并非完全确定性**，  \n",
    "  因此，您的评估分数可能会与下方示例结果 **略有不同**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4f700d4b-19e5-4404-afa7-b0f093024232",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Writing entries: 100%|██████████| 1100/1100 [17:20<00:00,  1.06it/s]\n"
     ]
    }
   ],
   "source": [
    "generate_model_responses(json_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "838d9747-0f7d-46fe-aab5-9ee6b765d021",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"instruction-data-with-preference.json\", \"w\") as file:\n",
    "    json.dump(json_data, file, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
